#include "grain_growth.h"
#include "grain_growth_test_tool.h"

#include <chrono>
#include <cstdio>
#include <exception>
#include <iostream>
#include <string>
#include <vector>

// Global verbosity switch: when true, main() prints progress messages and
// checksum-style sums of the intermediate buffers.
// NOTE(review): this is a non-static global, so other translation units may
// reference it via `extern` -- confirm before changing its linkage or type.
bool debug_on = true;

int main(int argc, char **argv) {
    
    int seed;
    char* run_idx;

    if (argc > 2) {
        run_idx = argv[1];
        seed = std::stoi(argv[2]);
    }
    else {
        run_idx = "";
        seed = 4321;
    }

    /*
    ground truth param:
    A = 1.0
    B = 1.0
    L = 5.0     # the mobil in original code
    kappa = 0.1  # the grcoef in original code
    */
    valueType init_L = 2.0;
    valueType init_A = 2.0;
    valueType init_B = 3.0;
    valueType init_kappa = 0.9;

    double lr = 1e-4;
    int start_skip = 1;
    int skip_step = 5;
    int epoch = 10;

    char* data_path = "../data/grain_growth_all_data";
    GrainGrowthDataset dataset(data_path, start_skip, skip_step);

    int n_grains = dataset.n_grains;

    if (debug_on) {
        std::cout << "finish data loading" << std::endl;
    }

    double min_loss = 1000.0;

    for (int i = 0; i < epoch; ++ i) {
        double loss = 0.0;
        int total_size = 0;
        printf("epoch:\t%d\n", i);

        for (int index = start_skip; index < start_skip + 1; ++index) {
            ReturnItem rt = dataset.get_item(index);

            valueType* eta1_eta2_start = rt.data.eta1_eta2;
            valueType lshr = 0.01;
            uint lshK = 3;
            uint lshL = 10;
            int img_size = dataset.Nx;
            valueType h = 0.5;
            valueType dtime = 0.05;
            valueType ttime = 0.0;
            uint eta1_eta2_len = img_size * img_size * n_grains;

            if (debug_on) {
                std::cout << "sum of eta1_eta2_start: " << sum_mtx(eta1_eta2_start, eta1_eta2_len) << std::endl;
            }

            GrainGrowthOneStep one_step(img_size, img_size, n_grains, lshK, lshL, h,\
                                     init_A, init_B, init_L, init_kappa, dtime, lshr);
            one_step.encode_from_img(eta1_eta2_start);

            auto start = std::chrono::high_resolution_clock::now();
            for (int j = 0; j < skip_step; ++j) {
                one_step.next();
                std::cout << "sim step: " << j << std::endl;
            }
            auto stop = std::chrono::high_resolution_clock::now();
            auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
            std::cout << "time of ts model forward: " << duration.count() << "ms in " << skip_step << "steps" << std::endl;

            if (debug_on) {
                std::cout << "success forward in ts model" << std::endl;
            }

            valueType* eta1_eta2_sim = one_step.decode_to_img();

            if (debug_on) {
                std::cout << "sum of eta1_eta2_sim: " << sum_mtx(eta1_eta2_sim, eta1_eta2_len) << std::endl;
            }

            valueType* eta1_eta2_ref = rt.ref.eta1_eta2_ref;

            if (debug_on) {
                std::cout << "sum of eta1_eta2_ref: " << sum_mtx(eta1_eta2_ref, eta1_eta2_len) << std::endl;
            }

            valueType dloss[eta1_eta2_len];

            calculate_mse_loss(eta1_eta2_sim, eta1_eta2_ref, dloss, eta1_eta2_len);

            if (debug_on) {
                std::cout << "sum of dloss: " << sum_mtx(dloss, eta1_eta2_len) << std::endl;
            }

            if (debug_on) {
                std::cout << "success get loss" << std::endl;
            }

            valueType lshr_back = 0.01;
            uint lshK_back = 3;
            uint lshL_back = 10;

            GrainGrowthOneBack one_back(img_size, img_size, n_grains, lshK_back, lshL_back, h, \
                                         init_A, init_B, init_L, init_kappa, dtime, lshr_back);

            one_back.encode_from_img(eta1_eta2_sim, dloss);

            auto start_back = std::chrono::high_resolution_clock::now();
            for (int j = 0; j < skip_step; ++j) {
                // auto start = std::chrono::high_resolution_clock::now();
                one_back.next();
                // auto stop = std::chrono::high_resolution_clock::now();
                // auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
                // std::cout << "time of one forward: " << duration.count() << std::endl;
                std::cout << "sim back step: " << j << std::endl;
            }
            auto stop_back = std::chrono::high_resolution_clock::now();
            auto duration_back = std::chrono::duration_cast<std::chrono::milliseconds>(stop_back - start_back);
            std::cout << "time of ts model backward: " << duration_back.count() << "ms in " << skip_step << "steps" << std::endl;

            valueType* tsm_back_grad = one_back.decode_derivative(); // L A B kappa

            if (debug_on) {
                std::cout << "success batch loss backward" << std::endl;
            }

            init_L -= lr * tsm_back_grad[0];
            init_A -= lr * tsm_back_grad[1];
            init_B -= lr * tsm_back_grad[2];
            init_kappa -= lr * tsm_back_grad[3];

            if (debug_on) {
                std::cout << "success opt1 opt2 step()" << std::endl;
            }

            int this_size = 1;
            valueType batch_loss = sum_mse_loss(eta1_eta2_sim, eta1_eta2_ref, img_size * img_size * n_grains);
            loss += batch_loss;
            if (true) {
                std::cout << "--------------loss-----------------" << std::endl;
                std::cout << "batch loss: " << batch_loss << std::endl;
                std::cout << "--------------grad-----------------" << std::endl;
                one_back.print_derivative();
                std::cout << "--------------param----------------" << std::endl;
                std::cout << init_L<< std::endl;
                std::cout << init_A << std::endl;
                std::cout << init_B << std::endl;
                std::cout << init_kappa << std::endl;
            }
            total_size += this_size;
        }
    }

}
